Skip to content

[AutoDiff] Support differentiation of branching cast instructions. #32069

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 29, 2020

Conversation

dan-zheng
Copy link
Contributor

Support differentiation of is and as? operators.

These operators lower to branching cast SIL instructions, requiring control flow differentiation support:

  • checked_cast_br
  • checked_cast_value_br
  • checked_cast_addr_br

Resolves SR-12898.


Examples:

import _Differentiation

// checked_cast_br
func typeCheckOperator<T>(_ x: Float, _ metatype: T.Type) -> Float {
  if metatype is Int.Type {
    return x + x
  }
  return x * x
}
print(valueWithGradient(at: 3, in: { typeCheckOperator($0, Int.self) }))
// (value: 6.0, gradient: 2.0)
print(valueWithGradient(at: 3, in: { typeCheckOperator($0, Float.self) }))
// (value: 9.0, gradient: 6.0)

// checked_cast_addr_br
func conditionalCast<T: Differentiable>(_ x: T) -> T {
  if let _ = x as? Float {
    // Without enum differentiation support (TF-1004), using `y: Float?` value
    // produces a non-differentiability error.
  }
  return x
}
print(valueWithGradient(at: Float(3), in: conditionalCast))
// (value: 3.0, gradient: 1.0)

dan-zheng added 3 commits May 28, 2020 12:32
Add a common helper function `VJPEmitter::createTrampolineBasicBlock`.

Change `VJPEmitter::buildPullbackValueStructValue` to take an original basic
block instead of a terminator instruction.
This test was disabled in SR-12741 due to iphonesimulator-i386 failures.
Enabling the test on other platforms is important to prevent regressions.
Support differentiation of `is` and `as?` operators.

These operators lower to branching cast SIL instructions, requiring control
flow differentiation support.

Resolves SR-12898.
@dan-zheng dan-zheng requested review from rxwei and marcrasi May 28, 2020 19:58
getOpASTType(ccabi->getTargetFormalType()),
createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getSuccessBB()),
createTrampolineBasicBlock(ccabi, pbStructVal, ccabi->getFailureBB()),
ccabi->getTrueBBCount(), ccabi->getFalseBBCount());
Copy link
Contributor Author

@dan-zheng dan-zheng May 28, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: these newly added VJPEmitter visitors significantly duplicate code from SILCloner.

We can potentially just inherit SILCloner visitors by baking createTrampolineBasicBlock logic into VJPEmitter::remapBasicBlock. This changes the meaning of VJPEmitter::getOpBasicBlock used by other visitors though, so I'm not sure it would work.

Comment on lines +195 to +197
// CHECK-LABEL: sil hidden [ossa] @${{.*}}checked_cast_addr_nonactive_result{{.*}} : $@convention(thin) <T where T : Differentiable> (@in_guaranteed T) -> @out T {
// CHECK: checked_cast_addr_br take_always T in %3 : $*T to Float in %5 : $*Float, bb1, bb2
// CHECK: }
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I didn't add a test for checked_cast_value_br because I'm not sure what Swift codes lowers to it. The as? operators I tried all lowered to checked_cast_addr_br.

The added control flow differentiation support is generic so I'm pretty sure checked_cast_value_br differentiation works. We can add a test if we encounter a case.

Copy link
Contributor

@rxwei rxwei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice. Thanks!

@dan-zheng
Copy link
Contributor Author

@swift-ci Please smoke test

@dan-zheng
Copy link
Contributor Author

@swift-ci Please smoke test Linux

@dan-zheng dan-zheng merged commit c63153d into swiftlang:master May 29, 2020
@dan-zheng dan-zheng deleted the autodiff-branching-casts branch May 29, 2020 08:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants